Hi guys, so in this lecture I wanted to talk about speed in Python. Now, as beginners, I want to emphasize that you guys should be focusing on writing correct programs and not pay too much mind to how fast they run. But with that said, I do think that thinking about the speed of code can actually be fun; it's problem solving and a challenge.
The 'joy of fast cars' is a somewhat cryptic title, but the explanation of it is fairly straightforward: one of the things I find a lot of fun is trying to make my code more efficient. I genuinely enjoy the process of taking a bit of code and trying to come up with ways to make it faster. In my mind programming is at its most interesting when you can look past stuff like language syntax and instead focus on the very nature of the problem itself. The aim of today is to try to give you a glimpse of that.
The problem we shall be looking at today is the following:
How can we list all of the prime numbers from 0 to N?
Let's start by splitting the problem into two parts; first, we need a way of knowing if a number is prime or not. And once we have that, we need to check all the numbers 0 to N.
In [2]:
# Attempt 1
def is_prime(num):
    """Returns True if number is prime, False otherwise"""
    if num <= 1: return False # 0, 1 and negative numbers are not prime
    # check for factors
    for i in range(2,num): # for loop that iterates from 2 up to num-1. Each number in the iteration is called "i"
        if (num % i) == 0: # modular arithmetic; this asks if num is divisible by i (with no remainder).
            return False
    # If we have iterated through every number up to num without finding a divisor it must be prime.
    return True
# Making the list:
def get_primes(b):
    primes = []
    for num in range(0, b+1):
        if is_prime(num): # Yes, you can call functions inside other functions!
            primes.append(num) # If prime, add it to the list
    return primes
print(get_primes(400))
Okay cool, now that we have a working solution the next question is how to improve the speed. To improve speed, the first and most obvious starting point is to time the code. Let's do that now...
In [4]:
import sys
sys.path.append("./misc") # Adding to sys.path allows us to find "profile_code.py"
from profile_code import profile
result = profile(get_primes, 100_000) # Get all primes, 0..100,000
print(result)
In the above output, we can see that it took about 32 seconds to get all the primes up to 100,000.
We can also see that almost all of the time is taken by the "is_prime" function (tottime column). This is important; it means that improving the performance of "is_prime" is going to considerably increase performance, whereas improving the speed of the get_primes function will have almost no impact.
For example, list comprehensions are usually a little faster than for-loops, and so I could speed up the ‘get_primes’ function by using one. For argument's sake, let's suppose a list comprehension speeds that function up by a massive 90%!! That’s a huge improvement! But when we look at the total time we see that get_primes took 0.015 seconds on my machine. So a 90% improvement would save us about 0.014 seconds out of 32; nobody would notice.
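The `profile` helper used above comes from the course's own `profile_code.py`, which is not shown here. As a rough stand-in, and purely an assumption about what such a helper might do, the standard library's `cProfile` and `pstats` modules can produce a similar table with a tottime column. The names `profile_sketch`, `is_prime_demo` and `get_primes_demo` are made up for this example.

```python
import cProfile
import io
import pstats


def profile_sketch(func, *args):
    """Run func(*args) under cProfile and return the stats table as a string.

    Hypothetical stand-in for the course's `profile` helper (not its real code).
    """
    profiler = cProfile.Profile()
    profiler.enable()
    func(*args)
    profiler.disable()

    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    stats.sort_stats("tottime").print_stats()  # slowest functions first
    return buffer.getvalue()


# Small demo functions so the block runs on its own.
def is_prime_demo(num):
    if num <= 1:
        return False
    for i in range(2, num):
        if num % i == 0:
            return False
    return True


def get_primes_demo(b):
    return [num for num in range(b + 1) if is_prime_demo(num)]


report = profile_sketch(get_primes_demo, 2_000)
print(report)
```

The row for `is_prime_demo` should dominate the tottime column, which is the same pattern described above.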
Okay, so the function we need to improve is the ‘is_prime’ function. I think the first line of code to study is the for-loop:
for i in range(2, num):
So this is where the fun begins! To solve this puzzle we need to be a bit creative. Improving the speed of this line of code is not simply a matter of “know more Python”. Rather, we need to think logically and apply a splash of mathematics. Here, let me show you something:
In [5]:
for i in range(1, 21):
    print(i, "--", 20/i)
So this code is dividing the number 20 by 'i', where 'i' is 1-to-20. The salient point here is that once 'i' goes past 10 the result is never a whole number (apart from 20 itself, which gives 1). This makes a lot of sense when you think about it; the minimum number of ‘integer parts’ we can split X into (besides 1) is two. Thus, when we start looking at numbers greater than n/2 (but less than n) the result will never be a whole number. And that stands for all numbers, not just 20.
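Here is a quick brute-force check of that claim. It is only a sketch, and the helper name `divisors_above_half` is made up for this example.

```python
def divisors_above_half(n):
    """Return every divisor d of n with n/2 < d < n."""
    return [d for d in range(n // 2 + 1, n) if n % d == 0]


# No number in this range has a divisor strictly between n/2 and n.
for n in range(2, 1_000):
    assert divisors_above_half(n) == []

print(divisors_above_half(20))
```

The list comes back empty for every number tried, which is what the reasoning above predicts.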
Now, we can use this information to make our prime search smarter. As things currently stand, proving 1499 is prime requires about that many steps; our code is (at the moment) asking if numbers like 1001, 1002, 1003, ... are divisors of 1499, but as the above logic demonstrates these checks are unnecessary. So, if we stop iterating at the number 750 we can approximately halve the time it takes to confirm a prime number and still have a correct solution.
for i in range(2, num//2):
As a quick note, we are using integer division here because the range function cannot handle floats. Now, before we run the benchmark, we need to check for correctness; whenever you make a change, even if it is a small one, you should test it on a few inputs. We want to check we haven’t broken anything with our change (more on ‘regression testing’ later). With this in mind, I ran the following code on my machine (where is_prime is the old function and is_prime2 is with the change):
x = [i for i in range(0,30000) if is_prime(i)]
y = [i for i in range(0,30000) if is_prime2(i)]
print(x == y) ---> False
We have a bug batman! What went wrong? To find out, I ran the following bit of code:
x2 = set(x)
y2 = set(y)
x2.symmetric_difference(y2) ---> {4}
I converted the lists to sets because sets have this handy method for quickly telling the difference between two sets. It turns out we have two lists, each with 3200+ numbers, and the only difference is that one of these lists contains the number 4 and the other does not. So what’s the problem?
Well, our new function uses:
range(2, n//2)
and:
4//2 == 2
In short, range(2, 2) is empty, so for n = 4 the loop never runs and 4 gets reported as prime. Our change to the function works great for large inputs but breaks for tiny ones. I think the simplest fix to this problem is to use (n//2)+1, which should fix our error at an insignificant performance cost.
In [6]:
n = 4
for i in range(2, n//2):
    print("(1)...", i)
# Nothing is printed! WTF!!
# Okay, attempted fix:
for i in range(2, n//2+1):
    print("(2)..." , i)
In [9]:
# Attempt 2
def is_prime2(num):
    """Returns True if number is prime, False otherwise"""
    if num <= 1: return False
    # check for factors
    for i in range(2, (num//2) + 1): ## tweaked
        if (num % i) == 0:
            return False
    return True
# Making the list:
def get_primes2(b):
    primes = []
    for num in range(0, b+1):
        if is_prime2(num): ## call our new prime function...
            primes.append(num)
    return primes
In [10]:
import sys
sys.path.append("./misc") # Adding to sys.path allows us to find "profile_code.py"
from profile_code import profile
result = profile(get_primes2, 100_000) # Get all primes, 0..100,000
print(result)
So this small change has roughly halved the amount of time it takes to get all primes up to 100,000.
Are we done? Well actually I can think of a few more tweaks...
Let's think about the nature of primes for one moment. The definition of a prime is that it is only divisible by itself and 1. And since an even number is, by definition, divisible by 2 we know that the only prime that is even is 2.
Let’s think of a large odd number (not necessarily prime). Our code is going to ask if 2, 4, 6, 8, 10, 12... are divisors. But from the definition of even numbers we know that if 12, 18, 22, etc. are divisors of X then so must 2 be. Which therefore means if 2 is not a divisor then neither is 6, 8, 100, 102, etc.
In short, checking for 2 is equivalent to checking for all even numbers. Can we apply this insight to our code? I think so:
In [11]:
# Attempt 3
def is_prime3(num):
    """Returns True if number is prime, False otherwise"""
    if num <= 1:
        return False
    if num == 2:
        return True
    if num % 2 == 0:
        # notice that this check occurs AFTER we check if num == 2.
        return False
    # check for factors
    for i in range(3,num//2+1, 2): # range function starts at odd number with a step of 2.
        if (num % i) == 0:
            return False
    return True
# Making the list:
def get_primes3(b):
    primes = []
    for num in range(0, b+1):
        if is_prime3(num): ## call our new prime function...
            primes.append(num)
    return primes
So in the above code we check if a number is divisible by 2 just once. After that, we only check if odd numbers are divisors of n. This change looks like it roughly halves the search space. Okay, let's benchmark it!
In [12]:
import sys
sys.path.append("./misc") # Adding to sys.path allows us to find "profile_code.py"
from profile_code import profile
result = profile(get_primes3, 100_000) # Get all primes, 0..100,000
print(result)
So now we are down to 8secs.
Can we do better? Perhaps, but I’m out of ideas at this point. But hey, google is a treasure trove of information, I wonder if there is some other maths ‘trick’ out there we could use…
After googling, I found out that apparently we can use the square root of n! And here I’m going to reproduce a maths proof which I found here.
Imagine we have two numbers A, B such that A * B = N
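Now there are three possible cases: A < B, A = B, or A > B.
Notice that because A * B is the same as B * A, cases 1 and 3 are equivalent. So that means we only need to check cases 1 and 2. If A = B then A * B can be rewritten as A * A, and if A * A = N then that's the definition of square root. In case 1, A must be smaller than the square root of N (otherwise A * B would be bigger than N), so either way the smaller factor is never larger than sqrt(N).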
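The upshot of that argument is that a number which is not prime always has a divisor no bigger than its square root. Below is a small brute-force check of that statement; it is a sketch rather than a proof, and `smallest_divisor` is a helper name invented for this example.

```python
def smallest_divisor(n):
    """Return the smallest divisor of n that is greater than 1 (n itself if n is prime)."""
    for d in range(2, n + 1):
        if n % d == 0:
            return d


for n in range(2, 5_000):
    d = smallest_divisor(n)
    if d != n:  # n is not prime
        # The smaller factor of the pair (d, n // d) never exceeds sqrt(n).
        assert d * d <= n

print(smallest_divisor(91), smallest_divisor(97))
```

Here 91 is 7 * 13, so its smallest divisor is 7 (below the square root of 91, which is roughly 9.5), while 97 is prime and only finds itself.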
Alright, how can we implement this? Well, it took a bit of testing, but eventually I came up with this line (after importing the math module, of course):
for i in range(3, math.ceil(math.sqrt(num))+1, 2):
sqrt(n) is in many cases not a whole number, and as discussed elsewhere range requires an integer. That’s where math.ceil comes in; it rounds the square root up to the next integer (e.g. math.ceil(6.0003) ---> 7). I then add one because range stops just before its end value; without it, perfect squares like 9 and 25 would slip through as primes.
How much faster do we think this function will be? Let's run it!
In [16]:
# Attempt 4
from math import sqrt, ceil
def is_prime4(num):
    """Returns True if number is prime, False otherwise"""
    if num <= 1:
        return False
    if num == 2:
        return True
    if num % 2 == 0:
        return False
    # check for factors
    for i in range(3,ceil(sqrt(num))+1, 2): # up to sqrt of N (always rounded up)
        if (num % i) == 0:
            return False
    return True
# Making the list:
def get_primes4(b):
    primes = []
    for num in range(0, b+1):
        if is_prime4(num): ## call our new prime function...
            primes.append(num)
    return primes
In [20]:
import sys
sys.path.append("./misc") # Adding to sys.path allows us to find "profile_code.py"
from profile_code import profile
result = profile(get_primes4, 100_000) # Get all primes, 0..100,000
print(result)
When we started, it took 32 seconds to get all the primes from 0 to 100_000 (on my PC). With the square root trick it now takes 0.15 seconds; that's a massive difference.
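A side note on the loop bound in Attempt 4: `sqrt` works with floating-point numbers, which can be slightly off for very large inputs. Since Python 3.8 the standard library also offers `math.isqrt`, which returns the integer square root directly. The sketch below swaps it in. `is_prime_isqrt` is a name made up for this example, and this is a variant, not the code that was timed above.

```python
from math import isqrt


def is_prime_isqrt(num):
    """Trial division up to the integer square root (a variant, not the lecture's code)."""
    if num <= 1:
        return False
    if num == 2:
        return True
    if num % 2 == 0:
        return False
    # isqrt(num) is the largest integer whose square is <= num, so adding 1
    # makes the range include it.
    for i in range(3, isqrt(num) + 1, 2):
        if num % i == 0:
            return False
    return True


print([n for n in range(30) if is_prime_isqrt(n)])
```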
You might wonder why square root made such a big difference. The answer is to do with how functions scale with input size. For example, if we skip all the even numbers less than N, and we skip all the numbers greater than N / 2, then we have reduced the search space to approximately N / 4. This is a linear improvement; if the number is 1,000,000 then we still need to check 250,000 numbers. For 2,000,000 we check 500,000 numbers, and so on.
What about square root? Well, the square root of 1,000,000 is 1000 and the square root of 2,000,000 is about 1414. So as you can see, these numbers are growing at a much slower rate. And that's why it's considerably faster. For a more theoretical explanation, please google "Big O Notation".
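To put numbers on that, the sketch below counts how many candidate divisors each loop would visit in the worst case (that is, when the number turns out to be prime). It reuses the two range expressions from Attempt 3 and Attempt 4; the function names are made up for this example.

```python
from math import ceil, sqrt


def candidates_half(n):
    """Number of odd candidate divisors checked when stopping at n // 2 (Attempt 3)."""
    return len(range(3, n // 2 + 1, 2))


def candidates_sqrt(n):
    """Number of odd candidate divisors checked when stopping at sqrt(n) (Attempt 4)."""
    return len(range(3, ceil(sqrt(n)) + 1, 2))


for n in (1_000_000, 2_000_000):
    print(n, candidates_half(n), candidates_sqrt(n))
```

Doubling the input doubles the first count but barely moves the second one.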
Are we done?
So, our task is to create lists of primes up to N, and so far our strategy for improving performance has been to reduce the number of divisors we need to check.
Let's think outside the box for a moment. Imagine that there is a drained swimming pool full of small plastic balls and equally sized lead balls. We want to sort them. How might we do it?
Well, we could jump into the pool and pick the balls up one at a time. If it's heavy we put it to one side. Maybe we could improve this process by getting a friend to help. Maybe we get fifty friends, and at some point we realise that the way to improve performance is to manage people better (i.e. teamwork), and so you go down a rabbit-hole of small incremental improvements.
But then a new idea comes along! It's a solution that doesn't need fifty people; in fact it's a solution whose running time is independent of the number of people working on the task. Can you guess what the idea may be?
Okay here it is; fill the pool with water, the lead balls will stay at the bottom but the plastic balls will float. Get a big net and voila! The balls are sorted.
How can we apply this analogy to the current problem? Well, right now we are sort of searching for 'needles in haystacks'. We ask if n is prime, and then ask if n + 1 is prime, and so on. Thus far, our optimisation technique has been "wait a minute, we do not need to check every straw"; there are some numbers we can just ignore.
But what if there was some other way?
After a bit of research, it seems like the "Sieve of Atkin" is one of the fastest known algorithms for this job. But this algorithm is rather complex. Another approach is to use the "Sieve of Eratosthenes". This algorithm is easier to implement, and it may also be faster than our current method.
This method uses a different trick. If you want a detailed explanation check the wiki article, but basically it boils down to this:
So, if N is 20:
2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20
2 is prime, so now we get rid of 2 * 2, 2 * 3, ..., 2 * 10
2, 3, _, 5, _, 7, _, 9, _, 11, _, 13, _, 15, _, 17, _, 19, _
The next number is 3 so we remove 3 * 2, 3 * 3, 3 * 4, ..., 3 * 6
2, 3, _, 5, _, 7, _, _, _, 11, _, 13, _, _, _, 17, _, 19, _
The next is 5, so we remove 5 * 2 ... 5 * 4
And the next after that is 7 so we remove 7 * 2.
And then since 11 * 2 is greater than N we can stop the process.
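The walkthrough above can be reproduced with a few lines of code. This is only an illustrative sketch that prints one row per crossing-out pass; the function name `sieve_trace` is made up for this example.

```python
def sieve_trace(n):
    """Return one text row per crossing-out pass, with '_' for removed numbers."""
    numbers = list(range(2, n + 1))
    removed = set()
    rows = []
    for p in numbers:
        if p in removed:
            continue  # p was crossed out earlier, so it is not prime
        if p * 2 > n:
            break  # no multiples of p left inside the list
        removed.update(range(p * 2, n + 1, p))
        rows.append(", ".join("_" if x in removed else str(x) for x in numbers))
    return rows


for row in sieve_trace(20):
    print(row)
```

For N = 20 this prints four rows (for 2, 3, 5 and 7). The rows for 5 and 7 look the same as the row for 3, because their multiples had already been crossed out.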
Alright, so that explains the algorithm (more or less); let's code it up!
In [35]:
def sieve_of_eratosthenes(n):
    mark = [True] * (n + 1) # one slot per number 0..n, so that n itself is included (like get_primes)
    mark[0] = False
    mark[1] = False # 0 and 1 are not prime, so we set these values to false.
    i = 2 # start at index 2 since 2 is the smallest prime.
    while i < len(mark):
        p = i
        i += 1
        if not mark[p]:
            ## Ignore non-primes
            continue
        multiplier = 2
        while p * multiplier <= n:
            ## Set all the multiples of P to false (since they cannot be prime).
            mark[p * multiplier] = False
            multiplier += 1
    return [i for i in range(len(mark)) if mark[i]] ## Collect the indexes still marked True; these are the primes
In [31]:
import sys
sys.path.append("./misc") # Adding to sys.path allows us to find "profile_code.py"
from profile_code import profile
result = profile(sieve_of_eratosthenes, 100_000) # Get all primes, 0..100,000
print(result)
Okay, so it seems faster; 0.072 is less than 0.153. However, when you see a very small difference in time the result can be unreliable, because background processes and other things happening on the computer may be affecting the results. The simplest way to check that we really do have a genuine difference is to increase our input size. I know! Let's get all the prime numbers less than 10 million!
In [32]:
import sys
sys.path.append("./misc") # Adding to sys.path allows us to find "profile_code.py"
from profile_code import profile
result = profile(get_primes4, 10_000_000) # Get all primes, 0..10,000,000
print(result)
In [33]:
import sys
sys.path.append("./misc") # Adding to sys.path allows us to find "profile_code.py"
from profile_code import profile
result = profile(sieve_of_eratosthenes, 10_000_000) # Get all primes, 0..10,000,000
print(result)
So when we check for all primes less than 10 million the sieve_of_eratosthenes really stands out; 9 secs versus 66 sec for our previous champion.
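On the point about small timings being unreliable: besides using a bigger input, another common way to cut down the noise is to time the same call several times and keep the fastest run. The sketch below does that with the standard library's `timeit` module. The helper `best_time` and the `demo` workload are made up for this example and are not part of the course's profiling code.

```python
import timeit


def best_time(func, arg, repeats=5):
    """Time func(arg) several times and return the fastest run in seconds."""
    timings = timeit.repeat(lambda: func(arg), number=1, repeat=repeats)
    return min(timings)


def demo(n):
    """Stand-in workload so the block runs on its own."""
    return sum(i * i for i in range(n))


print(f"best of 5: {best_time(demo, 100_000):.4f} s")
```

The fastest run is the one least disturbed by whatever else the computer was doing, which is why the minimum is usually reported rather than the average.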
Anyway, the main lesson I want you to learn here is that optimisation can be thought of as being two separate ideas; the first idea is low-level tinkering, in other words, we look at all the small details and see if we can save a byte or two of memory here or there. But then there is ‘high-level’ optimisation, and that is where we try to come up with an entirely new (and hopefully better) strategy for solving the problem.
In this lecture we started searching for needles in haystacks (checking if each number was a prime). We improved that by checking fewer straws. And then we thought of a new idea, which was a bit like searching for needles by setting the haystack on fire, with the idea being that the only thing left will be the needles (i.e. the primes).
In [36]:
# Quick check to ensure all the functions return the correct answer..
a = get_primes(1000)
b = get_primes2(1000)
c = get_primes3(1000)
d = get_primes4(1000)
e = sieve_of_eratosthenes(1000)
a == b == c == d == e
Out[36]:
In this week's (optional) homework, your task is to try and write a bit of code that is faster than my code. And there are going to be two basic ways to do it; you can get your hands dirty and try some low-level optimisation or you can ditch all that and favour a high-level approach.
Unlike most of the homeworks, this is more about being clever than it is about understanding Python.
The Challenge: BEAT MY TIME!!
The below code will create a list of all ODD square numbers starting at 1 and ending at x. Example:
If x is 100, the squares are:
[1, 4, 9, 16, 25, 36, 49, 64, 81, 100]
Of which, we only want the odd numbers:
[1, 9, 25, 49, 81]
A few hints...
Please study the code below. Your job is either to make it faster by tinkering with it or, alternatively, to use your own algorithm.
In [ ]:
import math
# My code, this is the function to beat! How can you improve it?
def squares(x):
    lst = []
    for number in range(1, x+1):
        square = math.sqrt(number) # We call the square_root function on the number.
        if square.is_integer():
            # is_integer is a float method that returns True if the number has no fractional part.
            # for example, 4.0 = True, 4.89 = False
            if number % 2 != 0: # checks if number is odd.
                lst.append(number)
    return lst
print(squares(1000))
In [40]:
import sys
sys.path.append("./misc") # Adding to sys.path allows us to find "profile_code.py"
from profile_code import profile
import math
def my_squares(x):
    """
    X: an Int
    function returns a list of all odd square numbers <= X
    >>> my_squares(100)
    [1, 9, 25, 49, 81]
    """
    # YOUR CODE GOES HERE !!!.
    # Note, don't change name of this function, if you do, I cant test it!
    return hamster_squares(x) ## CHANGE ME!!!
##################################
# MY CODE, a.k.a THE CODE TO BEAT!
# Please do not change this!!!
def hamster_squares(x):
    lst = []
    for number in range(1, x+1):
        square = math.sqrt(number)
        if square.is_integer():
            if number % 2 != 0:
                lst.append(number)
    return lst
################## THE CONTROL PANEL ################################
#####################################################################
verbose = True # set to False if you dont want the time on line-by-line basis.
X = 5000000
# Lower X if tests are taking too long on your machine.
# Raise X if you want higher accuracy.
#####################################################################
teacher = hamster_squares(10000)
student = my_squares(10000)
correct = None
# TEST 1: CORRECTNESS
if teacher == student:
    print("CORRECTNESS TEST = PASSED")
    correct = True
else:
    print("CORRECTNESS TEST = FAILED", "NOW TRYING TO DEBUG...", sep="\n")
    # here is a bit of code to help you find the problem(s)!
    # returning a list?
    if not isinstance(student, list):
        print("... Try returning a list next time, not a bloody {} !".format(type(student)))
    # too many/too few items?
    elif len(teacher) != len(student):
        print(".... Your list has {} items, it should have {} items".format(len(student), len(teacher)))
    # small numbers correct?
    elif student[:10] != teacher[:10]:
        print("... Start of list incorrect.\nYOURS: {}\nEXPECTED: {}".format(student[:10], teacher[:10]))
    # testing for same items. Note that this test DOES NOT take order into consideration.
    else:
        ts = set(teacher)
        st = set(student)
        diff = ts.symmetric_difference(st)
        if diff:
            print("... The lists contain different numbers, these are... \n {}".format(diff))
# SPEED TESTS ... (just ignore this code)
if correct:
    print("...Now testing speed. Please, note, this may take a while...\n",
          "Also, I'd advise a margin of error of about +- 0.2 seconds\n")
    def string(i, func, detail):
        i = i.split("\n")
        s= "✿ Stats for {} function... \n{}".format(func, i[2])
        if detail:
            s = s + "\n" + "\n".join(i[3:-7]) + "\n"
        return s
    print("-------- Solution Comparison, where input size is {}. -------- \n".format(X))
    hs = profile(hamster_squares, X)
    print(string(hs, "Teacher's Squares", verbose))
    ss = profile(my_squares, X)
    print(string(ss, "'YOUR'", verbose))